/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef EXAMPLES_PAD_BROADCAST_CUSTOM_H
#define EXAMPLES_PAD_BROADCAST_CUSTOM_H
#include "kernel_operator.h"

// Buffer count shared by both queues of KernelBroadcastCustom: it is the second template
// argument of each TQue and the second argument of each TPipe::InitBuffer call.
constexpr int32_t BUFFER_NUM = 1;
/**
 * One-shot broadcast kernel helper.
 *
 * Process() copies the source tensor from global memory into a local tensor, expands it with
 * AscendC::BroadCast<T, dim, axis>, and copies the result back to global memory.
 *
 * @tparam T    element type of the source and destination tensors
 * @tparam dim  number of entries in each shape array; forwarded to AscendC::BroadCast
 * @tparam axis forwarded unchanged to AscendC::BroadCast as its axis template argument
 */
template <typename T, int32_t dim, int32_t axis>
class KernelBroadcastCustom {
    // dim sizes the member shape arrays below, so it has to be a valid array extent.
    static_assert(dim > 0, "dim must be positive: it sizes the stored shape arrays");

public:
    __aicore__ inline KernelBroadcastCustom()
    {}

    /**
     * Binds the global buffers, reserves the queue buffers and records the lengths and shapes.
     *
     * @param x         global address of the source data
     * @param y         global address of the destination data
     * @param srcLength number of T elements in the source
     * @param dstLength number of T elements in the destination
     * @param srcShape  source shape; must point to at least dim readable entries
     * @param dstShape  destination shape; must point to at least dim readable entries
     *
     * The shape entries are copied into the object, so the caller's arrays only need to stay
     * valid for the duration of this call (not until Process() runs).
     */
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, uint32_t srcLength, uint32_t dstLength,
        const uint32_t srcShape[dim], const uint32_t dstShape[dim])
    {
        // NOTE(review): presumably configures the overflow/saturation mode — confirm against
        // the AscendC documentation for AscendCUtils::SetOverflow.
        AscendC::AscendCUtils::SetOverflow(1);
        xGm.SetGlobalBuffer((__gm__ T *)x, srcLength);
        yGm.SetGlobalBuffer((__gm__ T *)y, dstLength);

        // Each queue gets BUFFER_NUM buffers large enough to hold the whole tensor.
        pipe.InitBuffer(inQueueX, BUFFER_NUM, srcLength * sizeof(T));
        pipe.InitBuffer(outQueueY, BUFFER_NUM, dstLength * sizeof(T));

        srcLength_ = srcLength;
        dstLength_ = dstLength;
        // Copy by value instead of keeping the caller's pointers, so Compute() never reads
        // through a pointer whose target may already be gone.
        for (int32_t i = 0; i < dim; ++i) {
            srcShape_[i] = srcShape[i];
            dstShape_[i] = dstShape[i];
        }
    }

    /** Runs the three stages once, in order: CopyIn, Compute, CopyOut. */
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    /** Copies srcLength_ elements from xGm into a tensor allocated from inQueueX and enqueues it. */
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> xLocal = inQueueX.AllocTensor<T>();
        AscendC::DataCopy(xLocal, xGm, srcLength_);
        inQueueX.EnQue(xLocal);
    }

    /** Dequeues the input, broadcasts it into a tensor from outQueueY, and enqueues the result. */
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> xLocal = inQueueX.DeQue<T>();
        AscendC::LocalTensor<T> yLocal = outQueueY.AllocTensor<T>();
        // Argument order is destination first: (dst tensor, src tensor, dst shape, src shape).
        AscendC::BroadCast<T, dim, axis>(yLocal, xLocal, dstShape_, srcShape_);

        outQueueY.EnQue<T>(yLocal);
        inQueueX.FreeTensor(xLocal);
    }

    /** Dequeues the result, copies dstLength_ elements to yGm, and frees the tensor. */
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> yLocal = outQueueY.DeQue<T>();
        AscendC::DataCopy(yGm, yLocal, dstLength_);
        outQueueY.FreeTensor(yLocal);
    }

private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueX;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> outQueueY;
    AscendC::GlobalTensor<T> xGm;
    AscendC::GlobalTensor<T> yGm;
    uint32_t srcLength_{0};     // element count of the source, set by Init
    uint32_t dstLength_{0};     // element count of the destination, set by Init
    uint32_t srcShape_[dim]{};  // owned copy of the source shape passed to Init
    uint32_t dstShape_[dim]{};  // owned copy of the destination shape passed to Init
};
#endif // EXAMPLES_PAD_BROADCAST_CUSTOM_H